/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.dllib.nn.mkldnn

import com.intel.analytics.bigdl.mkl.DataType
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity
import com.intel.analytics.bigdl.dllib.tensor.Tensor
import com.intel.analytics.bigdl.dllib.utils.{Log4Error, T}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
 * Keeps track of the `ReorderMemory` modules needed to convert data between memory formats.
 *
 * A reorder is registered per (source memory object, target format) pair. The source is
 * identified by `System.identityHashCode`, so two distinct but equal `MemoryData` instances
 * are tracked separately. Registering the same pair several times shares one reorder and
 * bumps its reference count; during inference the shared reorder is executed once per
 * cycle of `refCount` calls.
 *
 * `setRuntime` must be called before `register`.
 */
private[mkldnn] class ReorderManager() (implicit owner: MemoryOwner) {
  // (identity hash of the source MemoryData, target format) -> reorder module
  val reorders: mutable.HashMap[(Int, MemoryData), ReorderMemory] = mutable.HashMap.empty
  // identity hash of a reorder module -> how many times it has been registered
  val refCounts: mutable.HashMap[Int, Int] = mutable.HashMap.empty
  // identity hash of a reorder module -> how many lookups it served in the current cycle
  val useCounts: mutable.HashMap[Int, Int] = mutable.HashMap.empty

  private var runtime: MklDnnRuntime = _

  /**
   * Registers a conversion from `from` to `to`.
   *
   * Nothing is recorded when `needReorder` reports that no conversion is required.
   * Otherwise an existing reorder for the pair gets its reference count incremented,
   * or a new one is created, bound to the runtime and initialised for inference.
   *
   * @param from source memory description
   * @param to   target memory description
   */
  def register(from: MemoryData, to: MemoryData): Unit = {
    Log4Error.invalidInputError(runtime != null, "Please call setRuntime first")
    if (needReorder(from, to)) {
      val key = (System.identityHashCode(from), to)
      reorders.get(key) match {
        case Some(existing) =>
          // Same pair registered again: share the reorder, remember one more user.
          val existingId = System.identityHashCode(existing)
          refCounts(existingId) = refCounts(existingId) + 1
        case None =>
          val created = ReorderMemory(to)
          created.setRuntime(runtime)
          created.initFwdPrimitives(Array(from), Phase.InferencePhase)
          reorders(key) = created
          val createdId = System.identityHashCode(created)
          refCounts(createdId) = 1
          useCounts(createdId) = 0
      }
    }
  }

  /**
   * Stores the runtime that newly created reorders are bound to.
   *
   * @param runtime runtime handed to every reorder created by `register`
   */
  def setRuntime(runtime: MklDnnRuntime): Unit = {
    this.runtime = runtime
  }

  /**
   * Applies the registered reorders to `output`.
   *
   * With exactly one source format, `output` must be a tensor and a tensor is returned.
   * Otherwise `output` must be a table with as many entries as `from`, and a new table
   * holding the converted entries (1-based keys) is returned.
   *
   * @param from   source formats, one per output entry
   * @param to     target formats, indexed in parallel with `from`
   * @param output activity to convert
   * @return the converted activity
   */
  def infer(from: Array[MemoryData], to: Array[MemoryData], output: Activity)
  : Activity = {
    if (from.length == 1) {
      Log4Error.invalidInputError(output.isTensor, "output activity should be a tensor")
      inferTensor(from(0), to(0), output.asInstanceOf[Tensor[Float]])
    } else {
      Log4Error.invalidInputError(output.toTable.length() == from.length,
        "output activity length doesn't match")
      val converted = T()
      for (idx <- from.indices) {
        // Table keys are 1-based while the format arrays are 0-based.
        converted(idx + 1) = inferTensor(from(idx), to(idx), output.toTable(idx + 1))
      }
      converted
    }
  }

  /**
   * Converts one tensor with the reorder registered for (`from`, `to`), if there is one.
   *
   * The reorder is run only on the first call of each cycle; later calls in the same
   * cycle return `reorder.output` instead. A cycle ends, and the use counter goes back
   * to zero, once the number of calls equals the reorder's reference count.
   *
   * @return the reordered tensor, or `output` itself when no reorder is registered
   */
  private def inferTensor(from: MemoryData, to : MemoryData, output: Tensor[Float])
  : Tensor[Float] = {
    reorders.get((System.identityHashCode(from), to)) match {
      case None => output
      case Some(reorder) =>
        val reorderId = System.identityHashCode(reorder)
        val firstUseInCycle = useCounts(reorderId) == 0
        val result =
          if (firstUseInCycle) reorder.forward(output).asInstanceOf[Tensor[Float]]
          else reorder.output.asInstanceOf[Tensor[Float]]
        useCounts(reorderId) = useCounts(reorderId) + 1
        if (useCounts(reorderId) == refCounts(reorderId)) {
          // Every registered user has been served; start a new cycle.
          useCounts(reorderId) = 0
        }
        result
    }
  }

  /**
   * Decides whether data in format `from` has to be converted to reach format `to`.
   *
   * Heap to heap needs a reorder only when the layouts differ. Any move between heap
   * and native memory always needs one. Native to native is skipped when the layouts
   * match and the data types are either equal or an S8/U8 pair. Formats other than
   * `HeapData` and `NativeData` are reported through `Log4Error.invalidInputError`.
   */
  private def needReorder(from: MemoryData, to: MemoryData): Boolean = {
    (from, to) match {
      case (h: HeapData, hh: HeapData) => h.layout != hh.layout
      case (_: HeapData, _: NativeData) => true
      case (_: HeapData, _) =>
        Log4Error.invalidInputError(false, s"memory format $to is not supported",
          "only support NativeData and HeapData")
        true
      case (_: NativeData, _: HeapData) => true
      case (n: NativeData, nn: NativeData) =>
        if (n.layout != nn.layout) {
          true
        } else {
          val sameType = n.dataType == nn.dataType
          // S8 -> U8 is deliberately not reordered
          val s8ToU8 = n.dataType == DataType.S8 && nn.dataType == DataType.U8
          // U8 -> S8 is deliberately not reordered
          val u8ToS8 = n.dataType == DataType.U8 && nn.dataType == DataType.S8
          !(sameType || s8ToU8 || u8ToS8)
        }
      case (_: NativeData, _) =>
        Log4Error.invalidInputError(false, s"memory format $to is not supported",
          "only support NativeData and HeapData")
        true
      case _ =>
        Log4Error.invalidInputError(false, s"memory format $from is not supported",
          "only support NativeData and HeapData")
        true
    }
  }

  /** Does nothing in this class; the body is intentionally empty. */
  def release(): Unit = ()
}
